Variational Inference with Implicit Approximate Inference Models (WIP Pt. 3)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 140638038322816 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)140638038320800 dense_1: Denseinput:output:(None, 2)(None, 10)140638038322816->140638038320800 140638038322592 dense_2: Denseinput:output:(None, 10)(None, 20)140638038320800->140638038322592 140638038321808 logit: Denseinput:output:(None, 20)(None, 1)140638038322592->140638038321808 140638038360976 activation_1: Activationinput:output:(None, 1)(None, 1)140638038321808->140638038360976
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.48928388953208923, 0.80000001192092896]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 140638039992360 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)140638037305216 dense_3: Denseinput:output:(None, 3)(None, 10)140638039992360->140638037305216 140638037305048 dense_4: Denseinput:output:(None, 10)(None, 20)140638037305216->140638037305048 140638039956000 dense_5: Denseinput:output:(None, 20)(None, 2)140638037305048->140638039956000
In [31]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [32]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[32]:
(128, 2)
In [33]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[33]:
(128, 2)
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
Discriminator pre-training
In [36]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

    ax.scatter(*w_sample_posterior.T, alpha=.8)
    ax.scatter(*w_sample_prior.T, alpha=.8)

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [37]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[37]:
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 2)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
/Users/tiao/.virtualenvs/implicit/lib/python3.6/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated since IPython 4.0. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.
  "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning)
In [3]:
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
# display animation inline
plt.rc('animation', html='html5')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [86]:
w_min, w_max = -5, 5
In [87]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [88]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[88]:
(300, 300, 2)
In [89]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [90]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[90]:
(300, 300)
In [91]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [92]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [93]:
X = np.vstack((x1, x2, x3))
X.shape
Out[93]:
(3, 2)
In [94]:
y1 = 1
y2 = 1
y3 = 0
In [95]:
y = np.stack((y1, y2, y3))
y.shape
Out[95]:
(3,)
In [96]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [97]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[97]:
(300, 300, 3)
In [98]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [104]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [105]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.plot(*np.vstack((x1,x2,x3)).T, 'ro')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [106]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [107]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [108]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[108]:
G 4390601896 dense_6_input: InputLayerinput:output:(None, 2)(None, 2)4656746112 dense_6: Denseinput:output:(None, 2)(None, 10)4390601896->4656746112 4656743088 dense_7: Denseinput:output:(None, 10)(None, 20)4656746112->4656743088 4705728888 logit: Denseinput:output:(None, 20)(None, 1)4656743088->4705728888 4705708240 activation_2: Activationinput:output:(None, 1)(None, 1)4705728888->4705708240
In [109]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [110]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [111]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[111]:
[0.45842784643173218, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [112]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_9 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_10 (Dense)             (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [113]:
phi = inference.trainable_weights
phi
Out[113]:
[<tf.Variable 'dense_8/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_8/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_9/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_9/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_10/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_10/bias:0' shape=(2,) dtype=float32_ref>]
In [114]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[114]:
G 4638359224 dense_8_input: InputLayerinput:output:(None, 3)(None, 3)4706841152 dense_8: Denseinput:output:(None, 3)(None, 10)4638359224->4706841152 4638452536 dense_9: Denseinput:output:(None, 10)(None, 20)4706841152->4638452536 4638609928 dense_10: Denseinput:output:(None, 20)(None, 2)4638452536->4638609928
In [115]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [116]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[116]:
(128, 2)
In [117]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[117]:
(128, 2)
In [120]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [122]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
Discriminator pre-training
In [164]:
def prior_samples_gen(batch_size):

    while True:
        yield prior.rvs(size=batch_size)
In [165]:
def posterior_samples_gen(inference_model, batch_size):

    while True:
        eps = np.random.randn(batch_size, NOISE_DIM)
        yield inference_model.predict(eps)
In [172]:
def discriminator_data_gen(inference_model, batch_size):
    
    for samples_prior, samples_posterior in zip(prior_samples_gen(batch_size), 
                                                posterior_samples_gen(inference_model, batch_size)):
        inputs = np.vstack((samples_prior, samples_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        yield inputs, targets
In [184]:
h = discriminator.fit_generator(generator=discriminator_data_gen(inference, 128), steps_per_epoch=32, epochs=2)
Epoch 1/2
32/32 [==============================] - 0s - loss: 0.1257 - binary_accuracy: 0.9559     
Epoch 2/2
32/32 [==============================] - 0s - loss: 0.1169 - binary_accuracy: 0.9598     
In [185]:
h.history['loss'][-1]
Out[185]:
0.11688009253703058
In [167]:
metrics = discriminator.train_on_batch(D_input, D_labels)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-167-0dfc15023395> in <module>()
----> 1 metrics = discriminator.train_on_batch(D_input, D_labels)

NameError: name 'D_input' is not defined
In [36]:
def animate(step):

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

    info_dict = dict(zip(discriminator.metrics_names, metrics))
    info_dict['step'] = step

    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    t = ax.text(0.05, 0.85, 'step: 0', 
                transform=ax.transAxes, bbox=props)

    scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
    scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)
    
    return ax
In [37]:
FuncAnimation(fig, animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[37]:
In [38]:
fig, ax = plt.subplots(figsize=(7, 7))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

cset = ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
In [39]:
cset.collections
Out[39]:
<a list of 8 mcoll.PathCollection objects>
In [40]:
from matplotlib.collections import PatchCollection
In [ ]:
 
In [41]:
dir(cset.collections[0])
Out[41]:
['_A',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_agg_filter',
 '_alpha',
 '_animated',
 '_antialiaseds',
 '_axes',
 '_bcast_lwls',
 '_clipon',
 '_clippath',
 '_contains',
 '_edge_default',
 '_edgecolors',
 '_facecolors',
 '_factor',
 '_get_bool',
 '_get_value',
 '_gid',
 '_hatch',
 '_hatch_color',
 '_is_filled',
 '_is_stroked',
 '_label',
 '_linestyles',
 '_linewidths',
 '_mouseover',
 '_offset_position',
 '_offsets',
 '_oid',
 '_original_edgecolor',
 '_original_facecolor',
 '_path_effects',
 '_paths',
 '_picker',
 '_pickradius',
 '_prepare_points',
 '_prop_order',
 '_propobservers',
 '_rasterized',
 '_remove_method',
 '_set_edgecolor',
 '_set_facecolor',
 '_set_gc_clip',
 '_sizes',
 '_sketch',
 '_snap',
 '_stale',
 '_sticky_edges',
 '_transOffset',
 '_transform',
 '_transformSet',
 '_transforms',
 '_uniform_offsets',
 '_url',
 '_urls',
 '_us_linestyles',
 '_us_lw',
 '_visible',
 'add_callback',
 'add_checker',
 'aname',
 'autoscale',
 'autoscale_None',
 'axes',
 'callbacksSM',
 'changed',
 'check_update',
 'clipbox',
 'cmap',
 'colorbar',
 'contains',
 'convert_xunits',
 'convert_yunits',
 'draw',
 'eventson',
 'figure',
 'findobj',
 'format_cursor_data',
 'get_agg_filter',
 'get_alpha',
 'get_animated',
 'get_array',
 'get_axes',
 'get_children',
 'get_clim',
 'get_clip_box',
 'get_clip_on',
 'get_clip_path',
 'get_cmap',
 'get_contains',
 'get_cursor_data',
 'get_dashes',
 'get_datalim',
 'get_edgecolor',
 'get_edgecolors',
 'get_facecolor',
 'get_facecolors',
 'get_figure',
 'get_fill',
 'get_gid',
 'get_hatch',
 'get_label',
 'get_linestyle',
 'get_linestyles',
 'get_linewidth',
 'get_linewidths',
 'get_offset_position',
 'get_offset_transform',
 'get_offsets',
 'get_path_effects',
 'get_paths',
 'get_picker',
 'get_pickradius',
 'get_rasterized',
 'get_sizes',
 'get_sketch_params',
 'get_snap',
 'get_transform',
 'get_transformed_clip_path_and_affine',
 'get_transforms',
 'get_url',
 'get_urls',
 'get_visible',
 'get_window_extent',
 'get_zorder',
 'have_units',
 'hitlist',
 'is_figure_set',
 'is_transform_set',
 'mouseover',
 'norm',
 'pchanged',
 'pick',
 'pickable',
 'properties',
 'remove',
 'remove_callback',
 'set',
 'set_agg_filter',
 'set_alpha',
 'set_animated',
 'set_antialiased',
 'set_antialiaseds',
 'set_array',
 'set_axes',
 'set_clim',
 'set_clip_box',
 'set_clip_on',
 'set_clip_path',
 'set_cmap',
 'set_color',
 'set_contains',
 'set_dashes',
 'set_edgecolor',
 'set_edgecolors',
 'set_facecolor',
 'set_facecolors',
 'set_figure',
 'set_gid',
 'set_hatch',
 'set_label',
 'set_linestyle',
 'set_linestyles',
 'set_linewidth',
 'set_linewidths',
 'set_lw',
 'set_norm',
 'set_offset_position',
 'set_offsets',
 'set_path_effects',
 'set_paths',
 'set_picker',
 'set_pickradius',
 'set_rasterized',
 'set_sizes',
 'set_sketch_params',
 'set_snap',
 'set_transform',
 'set_url',
 'set_urls',
 'set_visible',
 'set_zorder',
 'stale',
 'stale_callback',
 'sticky_edges',
 'to_rgba',
 'update',
 'update_dict',
 'update_from',
 'update_scalarmappable',
 'zorder']
In [42]:
import matplotlib.patches as patches
In [43]:
from matplotlib.collections import PathCollection
In [44]:
fig, ax = plt.subplots(figsize=(7, 7))

ax.add_collection(cset.collections[4])

plt.show()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-d9b87f692bb4> in <module>()
      1 fig, ax = plt.subplots(figsize=(7, 7))
      2 
----> 3 ax.add_collection(cset.collections[4])
      4 
      5 plt.show()

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/axes/_base.py in add_collection(self, collection, autolim)
   1754             collection.set_label('_collection%d' % len(self.collections))
   1755         self.collections.append(collection)
-> 1756         self._set_artist_props(collection)
   1757 
   1758         if collection.get_clip_path() is None:

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/axes/_base.py in _set_artist_props(self, a)
    922     def _set_artist_props(self, a):
    923         """set the boilerplate props for artists added to axes"""
--> 924         a.set_figure(self.figure)
    925         if not a.is_transform_set():
    926             a.set_transform(self.transData)

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/artist.py in set_figure(self, fig)
    648         # to more than one Axes
    649         if self.figure is not None:
--> 650             raise RuntimeError("Can not put single artist in "
    651                                "more than one figure")
    652         self.figure = fig

RuntimeError: Can not put single artist in more than one figure
In [134]:
fig, ax = plt.subplots(figsize=(7, 7))

cset = ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
scat = ax.scatter(*np.random.randn(2, 128), alpha=.8)

ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

plt.show()
In [138]:
def animate(step):

    ax.cla()
    
    ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
    ax.scatter(*np.random.randn(2, 128), alpha=.8)

    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    
    return scat
In [139]:
FuncAnimation(fig, animate, frames=25, 
              interval=200) # 5 fps
Out[139]:
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP)

In [69]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [70]:
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import keras.backend as K

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from mpl_toolkits.mplot3d import Axes3D

from IPython.display import SVG
In [71]:
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [74]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [75]:
K.tf.__version__
Out[75]:
'1.2.1'
In [76]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.
In [77]:
w_min, w_max = -5, 5
In [78]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [79]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[79]:
(300, 300, 2)
In [80]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [81]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[81]:
(300, 300)
In [82]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [83]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [84]:
X = np.vstack((x1, x2, x3))
X.shape
Out[84]:
(3, 2)
In [85]:
y1 = 1
y2 = 1
y3 = 0
In [86]:
y = np.stack((y1, y2, y3))
y.shape
Out[86]:
(3,)
In [87]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [88]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[88]:
(300, 300, 3)
In [89]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [90]:
fig, ax = plt.subplots(figsize=(6, 5))

c = ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.colorbar(c)
plt.show()
In [91]:
fig, ax = plt.subplots(figsize=(6, 5))

c = ax.contourf(w1, w2, 
                np.exp(log_prior+np.sum(llhs, axis=2)), 
                cmap=plt.cm.magma)
ax.plot(*np.vstack((x1,x2,x3)).T, 'ro')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.colorbar(c)
plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [92]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy')
In [93]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [94]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[94]:
G 140717329860592 dense_6_input: InputLayerinput:output:(None, 2)(None, 2)140717330379888 dense_6: Denseinput:output:(None, 2)(None, 10)140717329860592->140717330379888 140717330379944 dense_7: Denseinput:output:(None, 10)(None, 20)140717330379888->140717330379944 140717329859024 logit: Denseinput:output:(None, 20)(None, 1)140717330379944->140717329859024 140717329765656 activation_2: Activationinput:output:(None, 1)(None, 1)140717329859024->140717329765656
In [95]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [96]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [97]:
discriminator.evaluate(prior.rvs(size=5), np.ones(5))
5/5 [==============================] - 0s
Out[97]:
0.60724371671676636

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [98]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_9 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_10 (Dense)             (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [99]:
phi = inference.trainable_weights
phi
Out[99]:
[<tf.Variable 'dense_8/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_8/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_9/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_9/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_10/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_10/bias:0' shape=(2,) dtype=float32_ref>]
In [100]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[100]:
G 140717330378936 dense_8_input: InputLayerinput:output:(None, 3)(None, 3)140717330676480 dense_8: Denseinput:output:(None, 3)(None, 10)140717330378936->140717330676480 140717329573760 dense_9: Denseinput:output:(None, 10)(None, 20)140717330676480->140717329573760 140718031384192 dense_10: Denseinput:output:(None, 20)(None, 2)140717329573760->140718031384192
In [101]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [102]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[102]:
(128, 2)
In [103]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[103]:
(128, 2)
In [104]:
fig, ax = plt.subplots(figsize=(6, 5))

c = ax.contourf(w1, w2, 
                np.exp(log_prior+np.sum(llhs, axis=2)), 
                cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.colorbar(c)
plt.show()
In [105]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.scatter(*w_posterior_samples.T, alpha=.6)
ax.scatter(*w_prior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [106]:
sess = K.get_session()
In [107]:
w_prior_samples = K.random_normal(shape=(BATCH_SIZE, LATENT_DIM), 
                                  stddev=np.sqrt(PRIOR_VARIANCE))
In [108]:
eps = K.random_normal(shape=(BATCH_SIZE, NOISE_DIM))
In [109]:
w_posterior_samples = inference(eps)
w_posterior_samples
Out[109]:
<tf.Tensor 'sequential_2/dense_10/BiasAdd:0' shape=(128, 2) dtype=float32>
In [110]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.scatter(*sess.run(w_posterior_samples).T, alpha=.6)
ax.scatter(*sess.run(w_prior_samples).T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [111]:
discrim_loss = K.mean(
    K.binary_crossentropy(
        discriminator(w_posterior_samples), 
        K.ones_like(discriminator(w_posterior_samples))) +
    K.binary_crossentropy(
        discriminator(w_prior_samples), 
        K.zeros_like(discriminator(w_prior_samples))))
In [112]:
discrim_loss.eval(session=sess)
Out[112]:
1.385956
In [113]:
opt = K.tf.train.AdamOptimizer(3e-3, beta1=0.9)
In [114]:
discrim_train_op = opt.minimize(discrim_loss, 
                                var_list=discriminator.trainable_weights)
In [115]:
K.mean(ratio_estimator(w_posterior_samples))
Out[115]:
<tf.Tensor 'Mean_14:0' shape=() dtype=float32>
In [116]:
K.expand_dims(K.constant(y), 1)
Out[116]:
<tf.Tensor 'ExpandDims_8:0' shape=(3, 1) dtype=float32>
In [117]:
K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1))
Out[117]:
<tf.Tensor 'Pow_7:0' shape=(3, 1) dtype=float32>
In [118]:
K.dot(K.constant(X), K.transpose(w_posterior_samples))
Out[118]:
<tf.Tensor 'MatMul_7:0' shape=(3, 128) dtype=float32>
In [119]:
K.dot(K.constant(X), K.transpose(w_posterior_samples))*K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1))
Out[119]:
<tf.Tensor 'mul_10:0' shape=(3, 128) dtype=float32>
In [120]:
K.sigmoid(K.dot(K.constant(X), K.transpose(w_posterior_samples)) *
          K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1)))
Out[120]:
<tf.Tensor 'Sigmoid_5:0' shape=(3, 128) dtype=float32>
In [121]:
K.log(K.sigmoid(K.dot(K.constant(X), K.transpose(w_posterior_samples)) *
                K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1))))
Out[121]:
<tf.Tensor 'Log_10:0' shape=(3, 128) dtype=float32>
In [122]:
K.mean(K.log(K.sigmoid(K.dot(K.constant(X), K.transpose(w_posterior_samples)) *
                       K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1)))))
Out[122]:
<tf.Tensor 'Mean_15:0' shape=() dtype=float32>
In [123]:
log_likelihood = K.mean(K.log(K.sigmoid(K.dot(K.constant(X), K.transpose(w_posterior_samples)) *
                                        K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1)))))
log_likelihood
Out[123]:
<tf.Tensor 'Mean_16:0' shape=() dtype=float32>
In [124]:
inference_loss = K.mean(ratio_estimator(w_posterior_samples)) - log_likelihood
In [125]:
inference_loss.eval(session=sess)
Out[125]:
0.62940556
In [126]:
inference_train_op = opt.minimize(inference_loss, 
                                  var_list=inference.trainable_weights)
In [127]:
keras_llh = K.reshape(K.sum(K.log(K.sigmoid(K.dot(K.constant(X), K.reshape(K.permute_dimensions(K.constant(w_grid), (2, 0, 1)), shape=(2, 300*300))) *
                                            K.pow(K.constant(-1), 1-K.expand_dims(K.constant(y), 1)))), 
                            axis=0), 
                      shape=(300, 300))
keras_llh
Out[127]:
<tf.Tensor 'Reshape_3:0' shape=(300, 300) dtype=float32>
In [128]:
fig, ax = plt.subplots(figsize=(6, 5))

c = ax.contourf(w1, w2, sess.run(keras_llh), cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.colorbar(c)
plt.show()
In [129]:
np.allclose(sess.run(keras_llh), 
            np.sum(llhs, axis=2))
Out[129]:
True
In [130]:
sess = K.get_session()
In [131]:
for d_step in range(3*300):
    loss, _ = sess.run([discrim_loss, discrim_train_op])
    print(loss)
1.38532
1.3579
1.37045
1.32634
1.33504
1.27227
1.28475
1.24287
1.22682
1.22881
1.21535
1.19214
1.15849
1.14416
1.14371
1.14085
1.09653
1.1075
1.10195
1.09052
1.05989
1.08887
1.04995
1.05183
1.05929
1.03833
1.01648
1.03497
1.01955
0.960189
0.988754
0.951096
1.00077
0.951274
0.980649
0.966437
0.972831
0.949386
0.957259
0.926495
0.937377
0.916226
0.937085
0.914591
0.889111
0.904115
0.874599
0.862437
0.820481
0.895674
0.840388
0.906279
0.875646
0.839755
0.848321
0.792723
0.82651
0.818164
0.811942
0.773857
0.748018
0.803157
0.784107
0.763402
0.749615
0.775418
0.760771
0.789065
0.773577
0.716524
0.741916
0.747315
0.7378
0.732429
0.715211
0.701757
0.66048
0.696592
0.67441
0.715436
0.752388
0.672804
0.655467
0.693615
0.662632
0.660639
0.670306
0.625209
0.672787
0.661013
0.612875
0.636962
0.598916
0.662136
0.55278
0.613553
0.58413
0.567281
0.625276
0.573121
0.602104
0.579841
0.568854
0.565051
0.569442
0.54241
0.546366
0.538198
0.552248
0.601788
0.509559
0.536433
0.548869
0.538587
0.507484
0.508999
0.503613
0.476276
0.499506
0.508605
0.467023
0.461849
0.481012
0.472659
0.460581
0.45003
0.499749
0.454548
0.449862
0.536963
0.448475
0.467067
0.448574
0.486148
0.452766
0.437414
0.456874
0.459249
0.422732
0.420728
0.440434
0.461794
0.42356
0.463026
0.333407
0.433452
0.452114
0.35519
0.38314
0.367164
0.404151
0.342661
0.333128
0.346002
0.371624
0.452988
0.459696
0.428007
0.354236
0.311349
0.367938
0.462415
0.324165
0.452471
0.370054
0.333621
0.340891
0.447923
0.442891
0.336959
0.304018
0.333281
0.329424
0.321482
0.336908
0.375099
0.281343
0.3104
0.393116
0.393949
0.254872
0.444985
0.397345
0.354726
0.345908
0.228596
0.398022
0.353818
0.270504
0.336221
0.404845
0.321817
0.377144
0.444702
0.349187
0.339988
0.323128
0.402443
0.300859
0.288087
0.25431
0.299444
0.255488
0.313054
0.389661
0.230185
0.415969
0.392863
0.321102
0.322946
0.318168
0.248868
0.430951
0.271098
0.340435
0.342456
0.25684
0.340246
0.372529
0.45213
0.223858
0.288524
0.313052
0.232402
0.330251
0.378014
0.263161
0.252301
0.291386
0.300431
0.262314
0.434332
0.367387
0.252038
0.244233
0.402506
0.345745
0.255478
0.262207
0.311652
0.336827
0.268766
0.364561
0.325013
0.393565
0.361726
0.357331
0.233737
0.320791
0.228293
0.282431
0.288983
0.407007
0.337941
0.304532
0.271827
0.267347
0.362266
0.304938
0.265404
0.255674
0.355238
0.406194
0.256441
0.314476
0.290617
0.361067
0.230431
0.231442
0.214638
0.27529
0.276817
0.288087
0.33216
0.302738
0.240562
0.342358
0.321828
0.431387
0.250847
0.27855
0.244946
0.269892
0.392373
0.289954
0.280752
0.379365
0.308509
0.380273
0.2464
0.327529
0.320829
0.293981
0.174126
0.365427
0.220253
0.333885
0.271242
0.308914
0.276282
0.266287
0.311934
0.348004
0.275128
0.207578
0.413157
0.224726
0.226337
0.253312
0.241829
0.315702
0.214764
0.326352
0.386239
0.291657
0.272628
0.346335
0.245491
0.334242
0.258078
0.354176
0.321975
0.258085
0.295802
0.305626
0.387766
0.273724
0.290287
0.321441
0.327561
0.255026
0.290386
0.280339
0.255527
0.298302
0.306683
0.195443
0.358449
0.241336
0.373126
0.253557
0.214792
0.305674
0.280378
0.343197
0.366787
0.276942
0.268094
0.168371
0.245901
0.298778
0.259956
0.25666
0.287455
0.269845
0.218483
0.294437
0.295181
0.305536
0.206319
0.280013
0.322677
0.288346
0.241131
0.266195
0.264982
0.296257
0.3003
0.377406
0.26884
0.268009
0.287718
0.227764
0.327918
0.207449
0.336152
0.234233
0.30972
0.37544
0.203731
0.294301
0.313933
0.271753
0.389465
0.412929
0.304692
0.298007
0.343537
0.255937
0.154347
0.317867
0.249846
0.345872
0.246336
0.277842
0.240931
0.298713
0.311668
0.317681
0.308057
0.256132
0.230962
0.365042
0.326694
0.182882
0.298977
0.230151
0.319488
0.302716
0.337055
0.280472
0.330269
0.271947
0.253045
0.201525
0.212759
0.329366
0.2667
0.215311
0.318515
0.307372
0.270939
0.322351
0.326016
0.224039
0.200143
0.148069
0.322914
0.13252
0.305279
0.34209
0.192442
0.234607
0.224739
0.314081
0.216559
0.313695
0.287918
0.219836
0.280369
0.21938
0.330277
0.298507
0.229414
0.364292
0.321173
0.315297
0.247579
0.200411
0.318642
0.303608
0.204451
0.199351
0.328443
0.221468
0.234873
0.31292
0.257741
0.246582
0.175981
0.323016
0.331048
0.308028
0.357653
0.174243
0.24941
0.22787
0.237612
0.274157
0.266045
0.278884
0.30807
0.224181
0.21667
0.219935
0.215113
0.295749
0.228779
0.287847
0.274755
0.290104
0.257592
0.158541
0.307881
0.25223
0.246907
0.207552
0.256707
0.271685
0.263728
0.195208
0.297048
0.323941
0.256202
0.296758
0.304664
0.282952
0.325998
0.222861
0.298897
0.260353
0.22891
0.236029
0.267145
0.257044
0.235449
0.190918
0.327772
0.339791
0.271775
0.258712
0.315136
0.292156
0.223943
0.253757
0.249119
0.329328
0.240682
0.239011
0.330561
0.250688
0.275255
0.194366
0.217429
0.278699
0.363364
0.38697
0.241689
0.220995
0.258278
0.245144
0.284179
0.229927
0.172926
0.252574
0.335411
0.254785
0.348012
0.214253
0.278575
0.238369
0.257474
0.271734
0.243277
0.277504
0.341744
0.319304
0.317398
0.280298
0.340003
0.28844
0.218135
0.293611
0.336241
0.241568
0.354696
0.218075
0.284119
0.256733
0.215865
0.418464
0.153256
0.277758
0.34079
0.268386
0.207642
0.248794
0.214411
0.29601
0.309726
0.199095
0.320172
0.192783
0.197918
0.377143
0.2568
0.177411
0.351495
0.130222
0.340184
0.291682
0.290544
0.227707
0.278339
0.292904
0.369405
0.232014
0.243364
0.261748
0.273672
0.213081
0.235931
0.181272
0.289088
0.276575
0.174001
0.303271
0.187932
0.282802
0.239955
0.284414
0.243658
0.221988
0.258108
0.246219
0.289428
0.257696
0.229551
0.21741
0.328822
0.311971
0.254872
0.316906
0.231207
0.262648
0.306369
0.288493
0.271459
0.192031
0.303622
0.41162
0.223178
0.259723
0.214679
0.397215
0.174605
0.295229
0.268848
0.228744
0.264676
0.265323
0.2117
0.24944
0.200965
0.338666
0.355019
0.345025
0.30573
0.21673
0.330682
0.274113
0.192523
0.270616
0.289474
0.211751
0.252088
0.231713
0.332725
0.199363
0.250235
0.288874
0.22122
0.276812
0.265646
0.223562
0.262645
0.350757
0.265914
0.22279
0.334539
0.28883
0.286489
0.25148
0.410521
0.221324
0.322639
0.262011
0.209446
0.189491
0.23558
0.220028
0.192532
0.163899
0.271476
0.24814
0.338593
0.256927
0.251407
0.267802
0.290043
0.245387
0.262148
0.215823
0.261929
0.282015
0.211442
0.221292
0.340921
0.248831
0.277278
0.221039
0.216445
0.386351
0.212501
0.197268
0.285732
0.250364
0.285536
0.23757
0.269246
0.22951
0.183171
0.250872
0.30216
0.236106
0.358883
0.251053
0.266955
0.258212
0.367088
0.332073
0.323386
0.242664
0.226073
0.238923
0.304846
0.375678
0.330368
0.313987
0.19607
0.134187
0.268039
0.19912
0.301802
0.175763
0.146937
0.250498
0.286757
0.340791
0.19479
0.220702
0.417666
0.29547
0.255934
0.305307
0.237082
0.329802
0.265628
0.199395
0.243804
0.273119
0.235254
0.276254
0.28398
0.295666
0.294914
0.321386
0.246663
0.181925
0.201042
0.34126
0.22532
0.257131
0.330786
0.269258
0.269718
0.273251
0.141651
0.167173
0.223552
0.284899
0.218187
0.354566
0.252452
0.269409
0.173821
0.257801
0.196751
0.253583
0.203233
0.234229
0.197329
0.188659
0.255795
0.235756
0.165959
0.223946
0.227488
0.248884
0.187079
0.288895
0.260275
0.247191
0.234818
0.298696
0.159664
0.157798
0.226852
0.221296
0.301208
0.255157
0.219426
0.265181
0.242169
0.22737
0.368722
0.186773
0.174942
0.211064
0.225059
0.20583
0.266082
0.275465
0.209388
0.282925
0.28779
0.376804
0.199765
0.287445
0.224505
0.287785
0.218906
0.239334
0.296105
0.269542
0.254749
0.25984
0.206184
0.202108
0.232054
0.191722
0.287824
0.186813
0.223896
0.170904
0.329098
0.20976
0.16389
0.197081
0.21957
0.152862
0.266777
0.318929
0.284168
0.184058
0.189692
0.221393
0.146606
0.232144
0.267032
0.272488
0.224373
0.194534
0.319606
0.174593
0.180198
0.1856
0.209786
0.20547
0.227403
0.236377
0.235994
0.211703
0.263529
0.19659
0.279163
0.277781
0.176408
0.270348
0.24908
0.303667
0.176676
0.342215
0.293025
0.236686
0.178143
0.247409
0.198952
0.366561
0.280147
0.217366
0.281635
0.171967
0.226402
0.282493
0.279202
0.287876
0.179307
0.335104
0.227053
0.148146
0.280301
0.20616
0.247335
0.260018
0.292339
0.20556
0.277167
0.174411
0.385961
0.274886
0.271554
0.334994
0.249684
0.269627
0.279766
0.338049
0.215783
0.291618
0.236327
In [132]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.scatter(*sess.run(w_posterior_samples).T, alpha=.6)
ax.scatter(*sess.run(w_prior_samples).T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [133]:
for step in range(3*200):

    g_loss, _ = sess.run([inference_loss, inference_train_op])

    for d_step in range(3*50):

        d_loss, _ = sess.run([discrim_loss, discrim_train_op])
    
    print(d_loss, g_loss)
0.299976 3.34463
0.42477 3.42557
0.489073 3.16171
0.492646 2.78832
0.697428 2.15677
0.733206 1.72106
0.733298 1.52407
0.807746 1.29283
1.01707 1.36824
0.982902 1.16122
0.972628 1.10115
1.03581 1.21175
1.08277 1.57636
1.06537 1.18987
1.08656 1.15926
1.12824 1.12963
1.11872 1.36252
1.20765 1.06292
1.25455 1.00747
1.29707 1.10839
1.29748 1.12028
1.30085 0.914041
1.32332 0.986965
1.29178 0.985552
1.32019 1.00764
1.2936 1.06811
1.31705 1.0777
1.34883 1.0747
1.25078 0.984631
1.27866 1.05978
1.26078 1.02381
1.23433 0.950675
1.17123 0.905802
1.20704 1.02897
1.17519 0.958528
1.18514 0.944483
1.16954 0.834288
1.2178 0.968593
1.16209 0.839961
1.20248 0.932302
1.14983 0.903876
1.22635 0.947336
1.14538 1.02064
1.18963 0.978661
1.19453 0.916686
1.30689 0.953143
1.23519 0.983169
1.28594 0.952351
1.26316 0.985327
1.25242 0.95212
1.25774 0.928886
1.26975 0.893379
1.27822 0.898999
1.28621 0.888295
1.33779 0.949368
1.22241 0.955898
1.33705 0.829938
1.27576 0.985328
1.32099 0.844981
1.31612 0.929675
1.27131 0.906424
1.31378 0.931417
1.24642 0.892234
1.30976 0.890492
1.25219 0.913798
1.28763 0.873917
1.29087 0.897982
1.30312 0.900786
1.2624 0.86226
1.29503 0.897659
1.34493 0.907091
1.20294 0.901368
1.29374 0.888826
1.27329 0.933204
1.2393 0.855003
1.26956 0.798132
1.24166 0.98339
1.33536 0.908256
1.28197 0.909913
1.26196 0.928738
1.25027 0.899723
1.29102 0.904741
1.2854 0.895078
1.28242 0.932583
1.29949 0.93923
1.33095 0.933411
1.31844 0.906776
1.23557 0.914326
1.29453 0.88582
1.3628 0.929388
1.29 0.983679
1.28356 0.910055
1.29529 0.888072
1.27999 0.870004
1.2683 0.93354
1.28719 0.855926
1.25781 0.97036
1.23279 0.88822
1.29746 0.913348
1.1776 0.91485
1.23394 0.862434
1.24866 0.926131
1.26843 0.904428
1.28506 0.90711
1.28996 0.919347
1.30133 0.936005
1.19388 0.895903
1.29447 0.848564
1.29656 0.866625
1.30646 0.876477
1.34333 0.915293
1.31501 0.877347
1.29116 0.859644
1.24887 0.882841
1.24631 0.885617
1.29015 0.925848
1.31109 0.83671
1.27925 0.923947
1.1967 0.858732
1.27543 0.923373
1.30359 0.922775
1.32006 0.861679
1.28861 0.888932
1.25018 0.960168
1.28214 0.921384
1.2733 0.898344
1.24917 0.900367
1.29188 0.848968
1.2743 0.8913
1.2666 0.882793
1.31015 0.882384
1.33708 0.880494
1.33403 0.938472
1.30036 0.907784
1.2895 0.927621
1.22617 0.863041
1.2021 0.889226
1.29513 0.906224
1.21781 0.956144
1.24994 0.862738
1.28891 0.918831
1.26892 0.928247
1.29732 0.844343
1.28297 0.850404
1.31501 0.943559
1.2537 0.913817
1.22275 0.911185
1.31151 0.952311
1.27463 0.908037
1.30624 0.904855
1.24646 0.885781
1.28867 0.914277
1.2849 0.840575
1.35278 0.877746
1.2319 0.902562
1.31323 0.889612
1.26586 0.925208
1.32814 0.939342
1.28865 0.910893
1.3173 0.886235
1.26333 0.865478
1.23717 0.871529
1.23409 0.912993
1.30309 0.883133
1.28162 0.910106
1.28846 0.863426
1.30377 0.886678
1.23869 0.881498
1.27749 0.914509
1.35021 0.887712
1.27151 0.858169
1.24585 0.906618
1.28383 0.90202
1.29657 0.860148
1.3807 0.929281
1.23629 0.867771
1.26817 0.904284
1.27373 0.864512
1.23917 0.927002
1.27068 0.849798
1.31069 0.869198
1.35202 0.943909
1.22629 0.941014
1.23198 0.916758
1.31284 0.904447
1.25665 0.888998
1.24618 0.903868
1.35746 0.883824
1.26001 0.897694
1.26402 0.917581
1.25175 0.873426
1.26214 0.896461
1.24735 0.927409
1.25836 0.897048
1.32878 0.865081
1.25151 0.932814
1.31282 0.862772
1.29125 0.91402
1.29157 0.902586
1.3243 0.909463
1.26532 0.932931
1.25602 0.898872
1.31726 0.904133
1.34997 0.858721
1.32501 0.921767
1.30467 0.876867
1.26928 0.900812
1.27455 0.934993
1.26909 0.892616
1.29086 0.923215
1.3279 0.877819
1.22691 0.930301
1.32907 0.901793
1.27293 0.907734
1.30581 0.918265
1.26162 0.919997
1.28483 0.867349
1.28488 0.883433
1.28077 0.870668
1.18876 0.903459
1.25426 0.862026
1.26798 0.928929
1.31472 0.854127
1.27993 0.87624
1.27464 0.867176
1.29654 0.882978
1.32485 0.886917
1.26158 0.855248
1.24716 0.862936
1.28117 0.931524
1.20632 0.885412
1.20615 0.882369
1.23783 0.876454
1.30756 0.904714
1.26956 0.883896
1.24094 0.843747
1.23795 0.893069
1.28202 0.874765
1.23506 0.918747
1.22337 0.83743
1.3601 0.871304
1.24585 0.913947
1.24939 0.869258
1.31665 0.911868
1.29636 0.910884
1.2456 0.8817
1.27877 0.869583
1.22918 0.859839
1.26446 0.887992
1.19153 0.922629
1.27731 0.885781
1.24352 0.887125
1.27908 0.859451
1.29932 0.897275
1.29007 0.875759
1.31984 0.866664
1.27071 0.921769
1.22873 0.906817
1.28178 0.893999
1.21822 0.902762
1.27645 0.882699
1.33552 0.942864
1.29019 0.888469
1.26938 0.886925
1.33267 0.913781
1.23469 0.91071
1.24169 0.893096
1.27127 0.883584
1.27609 0.910159
1.24287 0.899395
1.27423 0.86692
1.26439 0.881307
1.27183 0.913777
1.32222 0.900624
1.30228 0.887633
1.3534 0.900266
1.28201 0.926251
1.24125 0.885375
1.27778 0.876981
1.34469 0.882942
1.20779 0.894412
1.29135 0.91812
1.36108 0.915039
1.25821 0.892367
1.27303 0.894492
1.27273 0.918805
1.35934 0.885502
1.22958 0.895696
1.31604 0.914404
1.30014 0.929029
1.32877 0.914306
1.2374 0.896187
1.29922 0.923526
1.31934 0.897103
1.25371 0.876657
1.27429 0.917426
1.19264 0.870283
1.21791 0.897719
1.28827 0.932747
1.29979 0.929157
1.34422 0.905238
1.27498 0.868193
1.32279 0.860391
1.30463 0.896259
1.28341 0.890942
1.26142 0.901491
1.29638 0.877999
1.2651 0.902819
1.31532 0.891434
1.26534 0.876294
1.27595 0.91941
1.29804 0.895347
1.35388 0.915468
1.30606 0.875761
1.2592 0.897856
1.28771 0.890425
1.25574 0.868415
1.30032 0.880557
1.24677 0.858536
1.32066 0.903354
1.34579 0.91493
1.21861 0.87927
1.27466 0.913171
1.26265 0.902121
1.30983 0.886372
1.25601 0.891181
1.21984 0.966641
1.27002 0.888857
1.33133 0.892649
1.2784 0.876429
1.34413 0.91359
1.28193 0.873876
1.33145 0.912212
1.33636 0.882785
1.18631 0.890936
1.29176 0.884748
1.30906 0.92735
1.27363 0.912045
1.29613 0.901083
1.27562 0.916263
1.2322 0.904348
1.33516 0.925911
1.28743 0.900918
1.2766 0.883029
1.34278 0.876535
1.20862 0.903666
1.25352 0.896885
1.2914 0.885103
1.22056 0.905637
1.27943 0.886183
1.29659 0.85713
1.33158 0.885512
1.25268 0.908388
1.27947 0.882604
1.25241 0.876734
1.26927 0.915657
1.28884 0.88941
1.21869 0.866007
1.28437 0.884098
1.28296 0.896861
1.31109 0.905855
1.29219 0.926
1.23675 0.880003
1.30739 0.915941
1.2876 0.87587
1.30799 0.907162
1.30787 0.900957
1.32889 0.891528
1.27516 0.868868
1.30612 0.89073
1.35053 0.907888
1.22153 0.884319
1.33086 0.917357
1.25139 0.87027
1.22702 0.930292
1.32561 0.885983
1.31254 0.894593
1.2653 0.877631
1.22026 0.879771
1.29784 0.907676
1.25376 0.895189
1.20714 0.888735
1.28772 0.915645
1.28619 0.890366
1.259 0.891858
1.28474 0.922514
1.25909 0.904846
1.26219 0.926151
1.32035 0.884415
1.32279 0.862578
1.28065 0.870504
1.2479 0.858193
1.34975 0.883737
1.21143 0.903801
1.29584 0.885905
1.27986 0.911445
1.27521 0.890818
1.28693 0.885505
1.31283 0.925045
1.24759 0.862306
1.312 0.927847
1.28998 0.9144
1.36809 0.887049
1.25219 0.905859
1.30386 0.887479
1.23131 0.900594
1.27164 0.902711
1.26537 0.874508
1.34274 0.875699
1.20408 0.878568
1.3324 0.880736
1.34016 0.910157
1.23586 0.88982
1.28405 0.904918
1.23743 0.891502
1.35555 0.873177
1.298 0.901068
1.23655 0.893607
1.27821 0.906993
1.24965 0.881674
1.26635 0.885439
1.24883 0.899231
1.24795 0.891216
1.28834 0.917066
1.31616 0.890811
1.25978 0.893421
1.19358 0.890076
1.28431 0.895919
1.22137 0.898382
1.27667 0.871932
1.31754 0.889699
1.27078 0.913125
1.34449 0.884558
1.31471 0.906929
1.27722 0.899061
1.29322 0.912265
1.29251 0.892235
1.3013 0.857401
1.30754 0.883708
1.321 0.920577
1.2749 0.909487
1.30773 0.89889
1.23442 0.907068
1.33218 0.889464
1.29985 0.897404
1.2953 0.906924
1.25735 0.877665
1.27992 0.901778
1.35627 0.923606
1.31298 0.887356
1.30567 0.907428
1.24799 0.886464
1.24527 0.89604
1.24649 0.867331
1.23434 0.866354
1.23262 0.906688
1.26906 0.864481
1.31386 0.880244
1.23446 0.875825
1.30225 0.871733
1.28758 0.942493
1.27359 0.8746
1.21847 0.914368
1.2833 0.876904
1.26146 0.869722
1.33898 0.910604
1.29383 0.893281
1.2845 0.88963
1.29536 0.875872
1.32227 0.892406
1.25826 0.920493
1.25819 0.906301
1.25176 0.888117
1.26221 0.904246
1.26325 0.892346
1.28616 0.872822
1.31126 0.8707
1.24719 0.897971
1.28051 0.914498
1.2817 0.879494
1.41589 0.886121
1.31908 0.909739
1.29879 0.886162
1.29675 0.896006
1.31087 0.875572
1.21402 0.907335
1.28895 0.904906
1.2114 0.895427
1.2151 0.902807
1.30628 0.880755
1.23686 0.912286
1.32852 0.892941
1.25243 0.904918
1.28739 0.89501
1.37896 0.883422
1.34163 0.900157
1.28176 0.91432
1.28685 0.929243
1.24621 0.903233
1.335 0.856527
1.28115 0.875209
1.27907 0.860913
1.3647 0.907754
1.27064 0.892142
1.30152 0.840292
1.25153 0.852177
1.29679 0.869337
1.27773 0.906446
1.27954 0.891281
1.23566 0.907888
1.33123 0.878992
1.29079 0.891348
1.39201 0.900257
1.29935 0.898852
1.29303 0.917898
1.24376 0.871551
1.36831 0.886407
1.25888 0.899488
1.2399 0.867973
1.30075 0.906447
1.26663 0.894795
1.28937 0.896779
1.28319 0.915181
1.2914 0.903844
1.22604 0.892134
1.28493 0.865643
1.25693 0.915223
1.17788 0.91207
1.31807 0.887731
1.31806 0.877444
1.33148 0.913566
1.31501 0.895104
1.21561 0.900399
1.32397 0.882778
1.2412 0.887779
1.28798 0.901722
1.32204 0.891469
1.30445 0.878949
1.3195 0.893537
1.29647 0.893651
1.30864 0.90272
1.33842 0.906772
1.27186 0.884606
1.25376 0.912435
1.24348 0.88483
1.32806 0.885722
1.26289 0.875498
1.34817 0.930449
1.2741 0.836675
1.31649 0.894658
1.28392 0.890792
1.3196 0.912111
1.32396 0.894139
1.26697 0.911493
1.23262 0.897811
1.22736 0.907085
1.22682 0.897704
1.23601 0.892503
1.21706 0.906444
1.2371 0.883
1.31262 0.876441
1.2599 0.89583
1.22319 0.91635
1.32081 0.877347
1.25163 0.909003
1.2904 0.878847
1.27256 0.882984
1.22443 0.863475
1.24356 0.915976
1.25261 0.886835
1.26352 0.875685
1.34715 0.879681
1.32883 0.882916
1.23169 0.914487
1.29897 0.870389
1.2282 0.893458
1.33507 0.866107
1.30039 0.866845
1.25275 0.871158
1.24227 0.896042
1.33085 0.8939
1.22338 0.893557
1.26838 0.905916
1.26894 0.895362
1.3019 0.881896
1.2906 0.887678
1.33256 0.882188
1.24752 0.90323
1.2359 0.907687
1.26341 0.883272
1.31385 0.899933
1.30353 0.852505
1.24364 0.898424
1.28896 0.898842
1.27565 0.90108
1.2543 0.879171
1.24601 0.927591
1.27888 0.894137
1.30922 0.911365
1.25757 0.857223
In [134]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.scatter(*sess.run(w_posterior_samples).T, alpha=.8)
ax.scatter(*sess.run(w_prior_samples).T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [138]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

sns.kdeplot(*sess.run(w_posterior_samples).T, cmap='magma', ax=ax2)

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

ax1.contourf(w1, w2, 
             np.exp(log_prior+np.sum(llhs, axis=2)), 
             cmap=plt.cm.magma)

ax1.scatter(*sess.run(w_posterior_samples).T, alpha=.6)

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

plt.show()
In [ ]:
 
In [ ]:
 

Visualizing the Latent Space of Vector Drawings from the Google QuickDraw Dataset with SketchRNN, PCA and t-SNE

t-SNE Visualization of Sheep Sketches

This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset 1, using the concurrently released SketchRNN model.

The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were given the challenge of drawing objects belonging to a particular class (such as "cat") in under 20 seconds.

SketchRNN is an impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly assembles many of the latest tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), Autoregressive models, Layer Normalization, Recurrent Dropout, the Adam optimizer, among others.

Read more…

Exploring the Google QuickDraw Dataset with SketchRNN (Part 3)

t-SNE Visualization of Sheep Sketches

This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset 1, using the concurrently released SketchRNN model.

The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were given the challenge of drawing objects belonging to a particular class (such as "cat") in under 20 seconds.

SketchRNN is an impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly assembles many of the latest tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), Autoregressive models, Layer Normalization, Recurrent Dropout, the Adam optimizer, among others.

Read more…

Exploring the Google QuickDraw Dataset with SketchRNN (Part 2)

This is the second part in a series of notes on my exploration of the recently released Google QuickDraw dataset, using the concurrently released SketchRNN model.

In the previous note, we set up our development environment, downloaded a subset of the data along with some pre-trained models, and developed some utilities for visualizing the data in the notebook. We retain most of the code from previous note and omit the expository code and markdown cells.


The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were given the challenge of drawing objects belonging to a particular class (such as "cat") in under 20 seconds.

SketchRNN is a very impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly combines many of the latest tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), Autoregressive models, Layer Normalization, Recurrent Dropout, the Adam optimizer, and others.

In [48]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [49]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import numpy as np
import tensorflow as tf

from matplotlib.animation import FuncAnimation
from matplotlib.path import Path
from matplotlib import rc

from six.moves import map
In [50]:
from magenta.models.sketch_rnn.sketch_rnn_train import \
    (load_env,
     load_checkpoint,
     reset_graph,
     download_pretrained_models,
     PRETRAINED_MODELS_URL)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (lerp,
                                             slerp,
                                             get_bounds, 
                                             to_big_strokes,
                                             to_normal_strokes)
In [52]:
# For inine display of animation
# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')
In [53]:
# set numpy output to something sensible
np.set_printoptions(precision=8, 
                    edgeitems=6, 
                    linewidth=200, 
                    suppress=True)
In [54]:
tf.logging.info("TensorFlow Version: {}".format(tf.__version__))
INFO:tensorflow:TensorFlow Version: 1.1.0

Getting the Pre-Trained Models and Data

In [7]:
DATA_DIR = ('http://github.com/hardmaru/sketch-rnn-datasets/'
            'raw/master/aaron_sheep/')
MODELS_ROOT_DIR = '/tmp/sketch_rnn/models'
In [8]:
DATA_DIR
Out[8]:
'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/'
In [9]:
PRETRAINED_MODELS_URL
Out[9]:
'http://download.magenta.tensorflow.org/models/sketch_rnn.zip'
In [10]:
download_pretrained_models(
    models_root_dir=MODELS_ROOT_DIR,
    pretrained_models_url=PRETRAINED_MODELS_URL)
INFO:tensorflow:/tmp/sketch_rnn/models/sketch_rnn.zip already exists, using cached copy
INFO:tensorflow:Unzipping /tmp/sketch_rnn/models/sketch_rnn.zip...
INFO:tensorflow:Unzipping complete.

We look at the layer normalized model trained on the aaron_sheep dataset for now.

In [11]:
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'
In [12]:
(train_set, 
 valid_set, 
 test_set, 
 hps_model, 
 eval_hps_model, 
 sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.
In [222]:
class SketchPath(Path):
    
    def __init__(self, data, factor=.2, *args, **kwargs):
        
        vertices = np.cumsum(data[::, :-1], axis=0) / factor
        codes = np.roll(self.to_code(data[::,-1].astype(int)), 
                        shift=1)
        codes[0] = Path.MOVETO

        super(SketchPath, self).__init__(vertices, 
                                         codes, 
                                         *args, 
                                         **kwargs)
        
    @staticmethod
    def to_code(cmd):
        # if cmd == 0, the code is LINETO
        # if cmd == 1, the code is MOVETO (which is LINETO - 1)
        return Path.LINETO - cmd
In [69]:
def draw(sketch_data, factor=.2, pad=(10, 10), ax=None):

    if ax is None:
        ax = plt.gca()

    x_pad, y_pad = pad
    
    x_pad //= 2
    y_pad //= 2
        
    x_min, x_max, y_min, y_max = get_bounds(data=sketch_data,
                                            factor=factor)

    ax.set_xlim(x_min-x_pad, x_max+x_pad)
    ax.set_ylim(y_max+y_pad, y_min-y_pad)

    sketch = SketchPath(sketch_data)

    patch = patches.PathPatch(sketch, facecolor='none')
    ax.add_patch(patch)

The real fun begins

Everything up to here has more or less been copied straight from the previous notebook. Now we load the pre-trained SketchRNN model and use it to begin our exploration of the test dataset.

In [110]:
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = True.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
In [111]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
In [112]:
# loads the weights from checkpoint into our model
load_checkpoint(sess=sess, checkpoint_path=MODEL_DIR)
INFO:tensorflow:Loading model /tmp/sketch_rnn/models/aaron_sheep/layer_norm/vector.
INFO:tensorflow:Restoring parameters from /tmp/sketch_rnn/models/aaron_sheep/layer_norm/vector

The helper functions for encoding a sketch to some latent code $z$ and then decoding it back to a sketch were provided in the original notebook. I just made some minor syntactic changes and removed the behaviour of plotting as a side-effect.

In [73]:
def encode(input_strokes):
    strokes = to_big_strokes(input_strokes).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    z = sess.run(eval_model.batch_z,
                 feed_dict={
                    eval_model.input_data: [strokes], 
                    eval_model.sequence_lengths: seq_len})[0]
    return z
In [74]:
def decode(z_input=None, temperature=.1, factor=.2):
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(
        sess, 
        sample_model, 
        seq_len=eval_model.hps.max_seq_len, 
        temperature=temperature, z=z)
    return to_normal_strokes(sample_strokes)

Now we get a random sample from the test dataset

In [120]:
sketch = test_set.random_sample()
In [138]:
fig, ax = plt.subplots(figsize=(3, 3),
                       subplot_kw=dict(xticks=[], 
                                       yticks=[], 
                                       frame_on=False))

draw(sketch, ax=ax)

plt.show()

We project it into the 128-dimensional latent space using the pre-trained encoder

In [183]:
z = encode(sketch)
z.shape
Out[183]:
(128,)

Now we can reconstruct the original sketch from the learned latent representation using the pre-trained decoder, with temperature $\tau=0.8$. The temperature parameter controls the level of randomness in the samples generated by the model, which becomes deterministic as $\tau \to 0$, and produces samples that are the most likely point in the probability density function. See pg. 7 of the original paper for further discussion of the effects the temperature parameter has on the sampling process.

In [216]:
sketch_reconstructed = decode(z, temperature=.6)
sketch_reconstructed.shape
Out[216]:
(250, 3)
In [217]:
fig, ax = plt.subplots(figsize=(3, 3),
                       subplot_kw=dict(xticks=[], 
                                       yticks=[], 
                                       frame_on=False))

draw(sketch_reconstructed, ax=ax)

plt.show()

Variance in the Reconstruction

The grid of drawings below consists of samples of the reconstructed drawings at various settings of the temperature parameter. The first column is the original drawing, and each of the remaining columns are 5 samples of the reconstructed drawing with $\tau$ increasing from 0.1 to 0.9.

In [182]:
fig, ax_arr = plt.subplots(nrows=5, 
                           ncols=10, 
                           figsize=(8, 4),
                           subplot_kw=dict(xticks=[],
                                           yticks=[],
                                           frame_on=False))
fig.tight_layout()

for row_num, ax_row in enumerate(ax_arr):    
    for col_num, ax in enumerate(ax_row):
        if not col_num:
            draw(sketch, ax=ax)
            xlabel = 'original'
        else:
            t = col_num / 10.
            draw(decode(z, temperature=t), ax=ax)
            xlabel = r'$\tau={}$'.format(t)
        if row_num+1 == len(ax_arr):
            ax.set_xlabel(xlabel)

plt.show()

At the lowest setting of the temperature at $\tau=0.1$, we see the samples consistently share a similar appearance - they all look like vertical strokes emanating from a fluffy cloud. However, they are also consistently dissimilar to the original sketch. In this sense, the samples from the models seems to exhibit high bias and low variance. As we increase the variance in the samples by increasing $\tau$, we start to find some samples that resemble our original sketch. But when we increase $\tau$ a little too much, beyond say 0.8, we begin to see a little too much randomness in the samples.

Drawing Comparisons

Humans typically write and, by extension, draw from left to right, top to bottom. Here, I wanted to animate the process of the original sketch being drawn alongside the decoder's reconstruction of the sketch to compare stroke patterns, typical stroke lengths, etc.

In [218]:
fig, (ax1, ax2) = plt.subplots(ncols=2, nrows=1, figsize=(6, 3),
                               subplot_kw=dict(xticks=[], 
                                               yticks=[]))
fig.tight_layout()

x_pad, y_pad = 10, 10
    
x_pad //= 2
y_pad //= 2

(x_min_1, 
 x_max_1,
 y_min_1,
 y_max_1) = get_bounds(data=sketch, factor=.2)

(x_min_2, 
 x_max_2, 
 y_min_2, 
 y_max_2) = get_bounds(data=sketch_reconstructed, factor=.2)

x_min = np.minimum(x_min_1, x_min_2)
y_min = np.minimum(y_min_1, y_min_2)

x_max = np.maximum(x_max_1, x_max_2)
y_max = np.maximum(y_max_1, y_max_2)

ax1.set_xlim(x_min-x_pad, x_max+x_pad)
ax1.set_ylim(y_max+y_pad, y_min-y_pad)

ax1.set_xlabel('Original')

ax2.set_xlim(x_min-x_pad, x_max+x_pad)
ax2.set_ylim(y_max+y_pad, y_min-y_pad)

ax2.set_xlabel('Reconstruction')
Out[218]:
<matplotlib.text.Text at 0x7f232a2097d0>
In [219]:
def animate(i):

    original = SketchPath(sketch[:i+1])
    reconstructed = SketchPath(sketch_reconstructed[:i+1])

    patch1 = ax1.add_patch(patches.PathPatch(original,
                                             facecolor='none'))

    patch2 = ax2.add_patch(patches.PathPatch(reconstructed, 
                                             facecolor='none'))
    
    return patch1, patch2
In [220]:
frames = np.maximum(sketch.shape[0], 
                    sketch_reconstructed.shape[0])
frames
Out[220]:
249
In [221]:
FuncAnimation(fig,
              animate,
              frames=frames-1, 
              interval=15,
              repeat_delay=1000*3, 
              blit=True)
Out[221]:

Unfortunately, the strokes that make up a sketch have been normalized with the Ramer–Douglas–Peucker algorithm, which is a simple stroke simplification process. This means the strokes aren't quite the same as the that which the human originally used to construct the sketch. Moreover, the timing of each stroke are also important to understanding patterns in how humans draw quick sketches. While timestamp data is provided in the full QuickDraw dataset, they are not preserved in the modified version of the dataset used by SketchRNN.